-
Notifications
You must be signed in to change notification settings - Fork 156
Tweaks to reshape Ops #1842
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Tweaks to reshape Ops #1842
Conversation
03342fc to
e8b8fb6
Compare
jessegrabowski
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Approved with small suggestions.
I still prefer a list of axes as the public API for join_dims over start_axis, n_axes, but I'm not willing to have a fight about it.
Counterargument is that for the internal uses we do have it was more awkward to do it like this and and you can't really do non-consecutive axis so there's also no benefit. So we were asking users to transform their range into a list of axis, just so we can go and undo that because we want the range anyway |
|
Yeah you are right, that's why I don't want to fight. But I think from a user perspective, it's more natural to pass the list. I'm remembering how confused I was by the |
|
What about split? Seems completely different thing |
|
It requires two arguments, splits_size and n_splits, when I would only have expected one. Similarly, I would only expect one argument for But I looked it up to type this message and I see that |
|
n_splits was needed before static shape, as It could have been avoided by expecting a sequence of scalars but they didn't go with that. |
|
For join_dims I think list of axis is fine if we support arbitrary axis. I strongly believe that if we ask for sequence of axis the user will expect they can be non-consecutive. So the API immediately forces the user to think about the valid cases. It's a range. We could rename it to join_dims_range? I'm still partial to a more flexible join_dims where we do the transpose to align the the dims to be joined for the user. But note that wasn't needed for our purposes of pack. Although then you need to decide where they end up, as there's no longer a reasonable default). xarray puts it at the end with stack but positions don't matter as much for them. |
|
Ah forgot to say why I did it though. You can't obtain the expand_dims edge case with the axis argument. There's no equivalent for Unless you offer the output_axis argument that decides where the joined axis goes |
|
I think this is fine as-is. We can revisit it if we really want a more complex operation. I think re-arranging the dimensions is out of scope for this Op. We're slowly just building back up to dimshuffle. |
It would be done by the helper, not the Op, So still combining DimShuffle and JoinDims, just doing the boilerplate work for the user. API would be something like: #1844 Agree with leaving it to another PR, but we have to decide on it soon, as it will be breaking change (this one already is) |
Also: * Allow default `None` on unpack
55c6e19 to
b7c3a87
Compare
tests/tensor/test_reshape.py
Outdated
| 2, | ||
| 1, | ||
| 3, | ||
| 4, | ||
| 5, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
peak ruff
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
just need to take the last comma
Mainly, joining 0 axes is equivalent to inserting a new dimension. This is the mirror of how splitting a single axis into an empty shape is equivalent to squeezing it.
b7c3a87 to
802f251
Compare
(join|split)_dims
I've changed
join_dims, to be a true mirror ofsplit_dims. TheJoinDimsOp itself already was, and if we make the helper also behave like theOpwe can simplify logic elsewhere.The signature is now
join(x, axis: int=0, n_axes: int | None = None).The main change is that:
expand_dims. This is the mirror ofsplit_dims(x, split_shape=())implying asqueeze.Also:
split_dimsaxis no longer accepts None, the default is 0.It also has the pleasant side-effect that you can't specify non-consecutive axis, which the other syntax would suggest is possible (before erroring out). You can only fail with axis or n_axes too large.
(un)pack
Rename (un)pack axes argument to
keep_axes.Also:
Noneon unpackAllow default ofRevertedNoneto work even with >1d inputsCherry picked from #1806
Closes #1835